# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import sys
from collections import OrderedDict, defaultdict
from pathlib import Path

# Make the repository root importable so the local 'utils' package resolves
# when this script is run directly.
sys.path.append(str(Path(__file__).absolute().parent.parent))

import numpy as np

from utils.metrics import METRICS
from utils.tabelle import Cell, Table



def create_tables(dataset_methods_data, method_order, ignore_nan):
    """Build a results table and return its LaTeX source.

    The table has one column block per dataset plus a trailing 'Mean' block,
    one sub-column per metric in METRICS, and one row per method.

    Args:
        dataset_methods_data: Ordered mapping of dataset name ->
            {method name -> {gt image path -> {metric name -> value}}}.
        method_order: List of method names defining the row order; its length
            must equal the largest per-dataset method count.
        ignore_nan: If True, nan metric values are skipped (with a warning)
            instead of propagating into the averages.

    Returns:
        The LaTeX source of the table as a string.
    """
    num_datasets = len(dataset_methods_data)
    num_metrics = len(METRICS)
    # A method may be missing for some datasets, so take the maximum count.
    num_methods = max((len(m) for m in dataset_methods_data.values()), default=0)
    assert num_methods == len(method_order), f'{num_methods} vs {len(method_order)}'
    print(f'num_datasets = {num_datasets}')
    print(f'num_methods = {num_methods}')

    # For each dataset, group the union of all ground-truth image names by
    # object (the parent directory name) so metrics can be averaged per
    # object first and then across objects.
    dataset_object_gtimages = {}
    for dsname, method_data in dataset_methods_data.items():
        all_gt_image_names = set()
        for data in method_data.values():
            all_gt_image_names.update(data.keys())

        object_gtimages = defaultdict(list)
        for x in all_gt_image_names:
            object_gtimages[Path(x).parent.name].append(x)
        dataset_object_gtimages[dsname] = object_gtimages

    num_col_blocks = num_datasets + 1  # one block per dataset plus 'Mean'

    # Rows: 2 header rows + one per method. Cols: label + metrics per block.
    table = Table((num_methods + 2, 1 + num_col_blocks * num_metrics))
    print(f'table shape is {table.shape}')

    table[0, 0].rowfmt.topmost_line = True
    table[0, 0].colfmt.align = 'l'
    table[1, 0] = Cell('Method', bold=True)
    # One header underline segment per column block (previously hard-coded
    # for 5 blocks of 3 metrics each).
    table[0, 0].rowfmt.line = [(1 + b * num_metrics, (b + 1) * num_metrics)
                               for b in range(num_col_blocks)]
    table[1, 0].rowfmt.line = True
    for ds_i, dsname in enumerate(dataset_methods_data):
        table[0, 1 + ds_i * num_metrics] = Cell(dsname, col_span=num_metrics, bold=True)
    # Compute the 'Mean' header position directly instead of reusing the loop
    # variable (undefined for an empty dataset dict) and span it over all
    # metrics instead of a hard-coded 3.
    table[0, 1 + num_datasets * num_metrics] = Cell('Mean', col_span=num_metrics, bold=True)
    table[1, 1:] = num_col_blocks * [x['latex'] for x in METRICS.values()]
    table[-1, 0].rowfmt.line = True
    for b in range(num_col_blocks):
        for metric_i, (metric_name, metric) in enumerate(METRICS.items()):
            col = num_metrics * b + metric_i + 1
            table[0, col].colfmt.auto_highlight = metric['best']
            # LPIPS values are small, so show one extra decimal digit.
            table[0, col].colfmt.num_format = '{:2.3f}' if metric_name == 'LPIPS' else '{:2.2f}'

    # Fill the per-dataset cells: average per object, then across objects.
    for ds_i, (dsname, method_data) in enumerate(dataset_methods_data.items()):
        col_offset = 1 + ds_i * num_metrics
        object_gtimages = dataset_object_gtimages[dsname]
        for method_i, method_name in enumerate(method_order):
            row = method_i + 2
            table[row, 0] = method_name
            data = method_data.get(method_name)
            if not data:
                continue  # no results for this method on this dataset

            for metric_i, metric in enumerate(METRICS):
                values = defaultdict(list)
                for objname, gtimages in object_gtimages.items():
                    for x in gtimages:
                        values[objname].append(data.get(x, {}).get(metric, np.nan))
                        if ignore_nan and np.isnan(values[objname][-1]):
                            print(f'Row "{table[row,0].value}" contains nan values!')

                col = col_offset + metric_i
                if ignore_nan:
                    table[row, col] = np.mean([np.nanmean(v) for v in values.values()])
                else:
                    table[row, col] = np.mean([np.mean(v) for v in values.values()])

    # Compute the trailing 'Mean' column block across datasets. With
    # ignore_nan set, a method missing entirely on one dataset (None cell ->
    # nan) no longer poisons its mean, matching the per-dataset handling.
    mean_fn = np.nanmean if ignore_nan else np.mean
    numpy_obj_table = table.numpy()
    for method_i, method_name in enumerate(method_order):
        row = method_i + 2
        for metric_i, metric in enumerate(METRICS):
            values = []
            for ds_i in range(num_datasets):
                cell = numpy_obj_table[row, 1 + ds_i * num_metrics + metric_i]
                values.append(float(cell) if cell is not None else np.nan)
            table[row, 1 + num_metrics * num_datasets + metric_i] = mean_fn(values)

    print(table)

    print('\n--- Latex\n')
    latex_str = table.latex()
    print(latex_str)
    print('\n---\n')
    return latex_str


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Script that creates latex tables from the json files generated by the evaluate.py script.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--dtu', type=Path, nargs='+', help="Path to the eval files for the DTU dataset")
    parser.add_argument('--bmvs', type=Path, nargs='+', help="Path to the eval files for the Blended MVS dataset")
    parser.add_argument('--ord_nvs', type=Path, nargs='+', help="Path to the eval files for the Object Relighting Dataset for novel view synthesis")
    # Fixed help text: previously a copy-paste of the --ord_nvs description.
    parser.add_argument('--synthetic4relight_nvs', type=Path, nargs='+', help="Path to the eval files for the Synthetic4Relight dataset for novel view synthesis")
    parser.add_argument('--method_order', type=str, nargs='+', help="List of strings to define the order of the method in the table")
    parser.add_argument('--output', type=Path, help="Path to the output tex file")
    parser.add_argument('--ignore_nan', action='store_true', help="If set nan values will be ignored. This can be used if some results are missing.")

    # Called without any arguments: show the help and bail out.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    print(args)

    method_names = set()
    dataset_methods_data = OrderedDict()

    # Maps the cmdline argument name to the dataset title shown in the table.
    ds_name = OrderedDict([
        ('dtu', 'DTU'),
        ('bmvs', 'BMVS'),
        ('synthetic4relight_nvs', 'Synthetic4Relight'),
        ('ord_nvs', 'Object Relighting Dataset'),
    ])

    # Load every provided eval json; each file must carry a unique method name
    # which becomes the row key.
    for ds, dsname in ds_name.items():
        paths = getattr(args, ds)
        if not paths:
            continue
        methods_data = {}
        for p in paths:
            with open(p, 'r') as f:
                method = json.load(f)
            assert 'name' in method, 'Each method must have a unique name because it is used as a key!'
            name = method['name']
            method_names.add(name)
            assert name not in methods_data, f'There is a duplicate for key "{name}"!'
            methods_data[name] = method['data']
        dataset_methods_data[dsname] = methods_data

    # Determine the row order: either fuzzy-match the user-provided names
    # against the loaded method names, or fall back to alphabetical order.
    if args.method_order:
        from thefuzz import process
        method_order = OrderedDict()
        for x in args.method_order:
            closest = process.extractOne(x, method_names)
            print(f'{x} -> {closest[0]}')
            method_order[x] = closest[0]
        assert len(set(method_order.values())) == len(args.method_order), 'cannot determine method order from cmdline arg'
        method_order = list(method_order.values())
    else:
        method_order = sorted(list(method_names))

    print('method order is ')
    for x in method_order:
        print(' ', x)

    latex = create_tables(dataset_methods_data=dataset_methods_data, method_order=method_order, ignore_nan=args.ignore_nan)
    if args.output:
        with open(args.output, 'w') as f:
            f.write(latex)